// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.exploration.mv.rollup;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;

import com.google.common.collect.ImmutableSet;

import java.util.Map;
import java.util.Set;

/**
 * Tries to roll up an aggregate function that contains DISTINCT when the function's parameter
 * is one of the materialized view's group-by dimensions.
 *
 * <p>For example, given the materialized view definition
 * {@code select empid, deptno, count(salary) from distinctQuery group by empid, deptno;}
 * and the query
 * {@code select deptno, count(distinct empid) from distinctQuery group by deptno;}
 * the rewrite should succeed: {@code count(distinct empid)} can be evaluated on the view's
 * {@code empid} group-by dimension.
 */
public class ContainDistinctFunctionRollupHandler extends AggFunctionRollUpHandler {

    public static final ContainDistinctFunctionRollupHandler INSTANCE = new ContainDistinctFunctionRollupHandler();

    /**
     * Template aggregate functions this handler accepts, covering both the distinct and
     * non-distinct forms of max/min, and both the plain and always-nullable forms where relevant.
     * Templates use {@code Any.INSTANCE} as the argument and are compared through
     * {@code Any.equals}. Presumably {@code Any.INSTANCE} matches any argument; confirm against
     * the {@code Any} implementation.
     */
    public static final Set<AggregateFunction> SUPPORTED_AGGREGATE_FUNCTION_SET = ImmutableSet.of(
            new Max(true, Any.INSTANCE), new Min(true, Any.INSTANCE),
            new Max(true, Any.INSTANCE).withAlwaysNullable(true),
            new Min(true, Any.INSTANCE).withAlwaysNullable(true),
            new Max(false, Any.INSTANCE), new Min(false, Any.INSTANCE),
            new Max(false, Any.INSTANCE).withAlwaysNullable(true),
            new Min(false, Any.INSTANCE).withAlwaysNullable(true),
            new Count(true, Any.INSTANCE),
            new Sum(true, Any.INSTANCE), new Sum(true, Any.INSTANCE).withAlwaysNullable(true),
            new Avg(true, Any.INSTANCE), new Avg(true, Any.INSTANCE).withAlwaysNullable(true));

    /**
     * Decides whether the query aggregate function can be rolled up by this handler.
     *
     * <p>All of the following must hold:
     * <ul>
     *   <li>the shuttled expression contains at most one aggregate function;</li>
     *   <li>that function matches an entry of {@link #SUPPORTED_AGGREGATE_FUNCTION_SET};</li>
     *   <li>it has at most one argument;</li>
     *   <li>every slot it references is a key of {@code mvExprToMvScanExprQueryBased}.</li>
     * </ul>
     *
     * @param queryAggregateFunction the aggregate function from the query (not inspected here)
     * @param queryAggregateFunctionShuttled the shuttled form of the query aggregate function
     * @param mvExprToMvScanExprQueryBasedPair unused by this handler
     * @param mvExprToMvScanExprQueryBased view expressions (query based) mapped to view scan expressions
     * @return {@code true} if {@link #doRollup} may be attempted
     */
    @Override
    public boolean canRollup(AggregateFunction queryAggregateFunction,
            Expression queryAggregateFunctionShuttled,
            Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair,
            Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
        Set<AggregateFunction> queryAggregateFunctions =
                queryAggregateFunctionShuttled.collectToSet(AggregateFunction.class::isInstance);
        if (queryAggregateFunctions.size() > 1) {
            return false;
        }
        for (AggregateFunction aggregateFunction : queryAggregateFunctions) {
            if (SUPPORTED_AGGREGATE_FUNCTION_SET.stream()
                    .noneMatch(supportFunction -> Any.equals(supportFunction, aggregateFunction))) {
                return false;
            }
            if (aggregateFunction.getArguments().size() > 1) {
                return false;
            }
        }
        Set<Expression> mvExpressionsQueryBased = mvExprToMvScanExprQueryBased.keySet();
        Set<Slot> aggregateFunctionParamSlots = queryAggregateFunctionShuttled.collectToSet(Slot.class::isInstance);
        // If the query uses any slot that is not in the view, it cannot be rolled up.
        return aggregateFunctionParamSlots.stream().allMatch(mvExpressionsQueryBased::contains);
    }

    /**
     * Rewrites the single argument of the query aggregate function onto the view's scan output.
     *
     * <p>Every slot in the argument is replaced by its value in
     * {@code mvExprToMvScanExprQueryBasedMap}. Only literals, binary arithmetic, slots and casts
     * are allowed in the argument. Any other expression node, or a slot missing from the map,
     * aborts the rewrite.
     *
     * @param queryAggregateFunction the query aggregate function whose argument is rewritten
     * @param queryAggregateFunctionShuttled unused by this implementation
     * @param mvExprToMvScanExprQueryBasedPair unused by this implementation
     * @param mvExprToMvScanExprQueryBasedMap view expressions (query based) mapped to view scan expressions
     * @return the function rebuilt over the rewritten argument, or {@code null} if the rollup is
     *         not possible
     */
    @Override
    public Function doRollup(AggregateFunction queryAggregateFunction,
            Expression queryAggregateFunctionShuttled, Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair,
            Map<Expression, Expression> mvExprToMvScanExprQueryBasedMap) {
        // Only single-argument functions can be rewritten. A zero-argument function would make
        // children().get(0) throw. For a multi-argument function, withChildren(rewrittenArgument)
        // would silently drop all but the first argument and change the query's meaning.
        if (queryAggregateFunction.children().size() != 1) {
            return null;
        }
        Expression argument = queryAggregateFunction.children().get(0);
        RollupResult<Boolean> rollupResult = RollupResult.of(true);
        Expression rewrittenArgument = argument.accept(new DefaultExpressionRewriter<RollupResult<Boolean>>() {
            @Override
            public Expression visitSlot(Slot slot, RollupResult<Boolean> context) {
                if (!mvExprToMvScanExprQueryBasedMap.containsKey(slot)) {
                    context.param = false;
                    return slot;
                }
                return mvExprToMvScanExprQueryBasedMap.get(slot);
            }

            @Override
            public Expression visit(Expression expr, RollupResult<Boolean> context) {
                // Once the rewrite has failed, stop descending.
                if (!context.param) {
                    return expr;
                }
                // Only these expression kinds are safe to rewrite through; anything else aborts.
                if (expr instanceof Literal || expr instanceof BinaryArithmetic || expr instanceof Slot
                        || expr instanceof Cast) {
                    return super.visit(expr, context);
                }
                context.param = false;
                return expr;
            }
        }, rollupResult);
        if (!rollupResult.getParam()) {
            return null;
        }
        return (Function) queryAggregateFunction.withChildren(rewrittenArgument);
    }

    /**
     * Mutable holder the rewriter uses to report whether the rewrite is still valid.
     *
     * @param <T> type of the carried value
     */
    private static class RollupResult<T> {
        public T param;

        private RollupResult(T param) {
            this.param = param;
        }

        public static <T> RollupResult<T> of(T param) {
            return new RollupResult<>(param);
        }

        public T getParam() {
            return param;
        }
    }
}
